import numpy as np
from sklearn.cluster import KMeans
from sklearn.cluster import kmeans_plusplus
from numpy.linalg import svd
import warnings
from sklearn.decomposition import TruncatedSVD
import torch

class Kz_Subspaces_algo:
    """(k, z) subspace clustering: fit ``k`` linear subspaces of dimension ``dim``.

    Each center is a ``(dim, n_features)`` array whose rows span a subspace
    through the origin. The cost of a point is its Euclidean distance to the
    nearest subspace raised to the power ``z``; the clustering cost is the sum
    over all points. Initial subspaces are seeded by distance-weighted sampling
    and then refined by alternating point assignment with a fixed number of
    AdamW steps per cluster.

    Distances are computed with torch on ``self.device`` (CUDA when available).
    ``data_gpu`` / ``subset_gpu`` cache device copies of the full dataset and of
    a subset; ``use_subset`` selects which cache distance queries try first.
    A cache is only used when its shape matches the array being queried;
    otherwise that array is converted on the fly.
    """

    def __init__(self, k, z, dim, max_sgd, max_iter, lr=0.01):
        """Configure the algorithm.

        Args:
            k: number of subspaces (clusters).
            z: exponent applied to point-to-subspace distances in the cost.
            dim: dimension of every subspace.
            max_sgd: number of AdamW steps per subspace refinement.
            max_iter: maximum number of assign/refine rounds in ``fit``.
            lr: AdamW learning rate used for subspace refinement.
        """
        self.k = k
        self.dim = dim
        self.z = z
        self.lr = lr
        self.max_sgd = max_sgd
        self.max_iter = max_iter
        self.centers = None          # (k, dim, n_features) array once fitted
        self.training_data = None    # last array passed to fit()
        self.history = []            # per-iteration costs of the last fit()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.data_gpu = None         # device copy of the full dataset
        self.subset_gpu = None       # device copy of the current subset
        self.use_subset = None       # truthy -> distance queries prefer subset_gpu
        print(f"Using device: {self.device}")

    def _to_device(self, data):
        """Return ``data`` as a contiguous float32 tensor on ``self.device``."""
        return torch.as_tensor(np.ascontiguousarray(data), dtype=torch.float32, device=self.device)

    def get_original_solution(self, data, n_init=10):
        """Fit ``n_init`` times on ``data`` and keep the lowest-cost solution.

        Args:
            data: array of shape ``(n_points, n_features)``.
            n_init: number of independent restarts (default 10).

        Returns:
            ``(best_centers, best_cost)``. ``self.centers`` is left set to
            ``best_centers`` so later ``score`` calls evaluate the best solution.
        """
        self.use_subset = False
        self.data_gpu = self._to_device(data)
        best_cost = None
        best_centers = None
        for _ in range(n_init):
            self.fit(data)
            cost = self.score(data)
            if best_cost is None or cost < best_cost:
                best_cost = cost
                best_centers = self.centers
        if best_centers is not None:
            self.centers = best_centers
        return best_centers, best_cost

    def get_subset_solution_original(self, dataset, subset_data):
        """Fit on ``subset_data`` and evaluate the resulting centers on both sets.

        Returns:
            ``(subset_cost, original_cost, centers)`` where ``original_cost`` is
            the cost of the subset-fitted centers on the full ``dataset``.
        """
        self.subset_gpu = self._to_device(subset_data)
        self.use_subset = True
        try:
            self.fit(subset_data)
            subset_cost = self.score(subset_data)
        finally:
            self.use_subset = False
        # Refresh the full-data cache so the evaluation really uses `dataset`,
        # not whatever array a previous call happened to cache.
        self.data_gpu = self._to_device(dataset)
        original_cost = self.score(dataset)
        return subset_cost, original_cost, self.centers

    def check_convergence(self, lookback=5, tolerance=0.98):
        """Return True when the recorded cost has stopped improving.

        Convergence is declared once at least ``lookback`` costs are recorded
        and the latest cost is not below ``tolerance`` times any of the last
        ``lookback`` costs (with the defaults: less than a 2% improvement).
        """
        if len(self.history) < lookback:
            return False
        recent = self.history[-lookback:]
        latest = recent[-1]
        if any(latest < tolerance * previous for previous in recent):
            return False
        print(f"early stopped with {self.history}")
        return True

    def fit(self, data):
        """Run alternating assignment / subspace refinement on ``data``.

        After the call ``self.centers`` holds the centers with the lowest cost
        seen during the run and ``self.history`` the cost of every round.
        """
        best_cost = np.inf
        best_centers = None
        self.history = []
        self.training_data = data
        self.centers = self.get_initial_subspaces(data)
        for _ in range(self.max_iter):
            clusters, cost = self.assign_points_batch(data)
            if cost < best_cost:
                best_cost = cost
                # Copy: self.centers is updated in place below, which would
                # otherwise overwrite the snapshot we are trying to keep.
                best_centers = self.centers.copy()
            self.history.append(cost)
            if self.check_convergence():
                break
            for j, members in enumerate(clusters):
                if members.shape[0] == 0:
                    continue
                new_center = self.find_best_subspace_iter(members, c=self.centers[j])
                self.centers[j] = self.orthogonalize_center(new_center)
        if best_centers is not None:
            self.centers = best_centers

    def orthogonalize_center(self, center):
        """Return a basis with orthonormal rows spanning the rows of ``center``.

        ``center`` is ``(dim, n_features)``; the reduced QR of its transpose
        yields an orthonormal ``(n_features, dim)`` factor, returned transposed.
        """
        q, _ = np.linalg.qr(center.T)
        return q.T

    def restart_center(self, data, min_dists):
        """Build a basis from the ``dim`` rows of ``data`` with the smallest ``min_dists``.

        NOTE(review): this selects the points already *closest* to a center;
        restarts usually seed from the farthest points — confirm intent. This
        method is not called anywhere in this class.
        """
        nearest_rows = np.argsort(min_dists)[:self.dim]
        return self.silent_svd(self.dim, data[nearest_rows])

    def find_best_subspace_iter(self, cluster, c):
        """Refine basis ``c`` for the points in ``cluster`` with ``max_sgd`` AdamW steps.

        Minimizes ``sum_i ||x_i - x_i U^T U|| ** z`` over the ``(dim, n_features)``
        matrix ``U`` initialized from ``c``. ``c`` is copied, never modified.

        Returns:
            The refined basis as a float32 numpy array (rows not re-orthonormalized).
        """
        # torch.tensor always copies, so the optimizer cannot write through to
        # the caller's numpy array (torch.from_numpy would share its memory).
        basis = torch.tensor(np.asarray(c), dtype=torch.float32, device=self.device, requires_grad=True)
        points = self._to_device(cluster)
        optimizer = torch.optim.AdamW([basis], lr=self.lr)
        for _ in range(self.max_sgd):
            optimizer.zero_grad()
            # Residual of each point after projecting onto span(rows of basis);
            # algebraically equal to (I - U^T U) x without building a d x d matrix.
            residual = points - (points @ basis.T) @ basis
            loss = (torch.linalg.norm(residual, dim=1) ** self.z).sum()
            loss.backward()
            optimizer.step()
        return basis.detach().cpu().numpy()

    def score(self, data):
        """Return the clustering cost of ``data`` under the current centers.

        Raises:
            ValueError: if no centers have been fitted yet.
        """
        if self.centers is None:
            raise ValueError("centers are not set; call fit() first")
        _, cost = self.assign_points_batch(data)
        return cost

    def assign_points_batch(self, data):
        """Assign every row of ``data`` to its nearest subspace.

        Returns:
            ``(clusters, cost)``: ``clusters[i]`` holds the rows of ``data`` whose
            nearest subspace is ``i`` (possibly empty), and ``cost`` is the sum
            of the per-point ``distance ** z`` values.
        """
        dists, labels = self.get_dists_to_nearest_center(data)
        cost = np.sum(dists)
        clusters = [data[labels == i] for i in range(self.k)]
        return clusters, cost

    def calculate_cost(self, point, center):
        """Return the summed ``distance ** z`` from ``point`` to the subspace ``center``.

        Args:
            point: a single point ``(n_features,)`` or a batch ``(n, n_features)``.
            center: basis whose rows span the subspace; a 1-D vector is treated
                as a single basis row.
        """
        point = np.asarray(point)
        basis = np.asarray(center).reshape(-1, point.shape[-1])
        residual = point - (point @ basis.T) @ basis
        dist = np.linalg.norm(residual, axis=-1) ** self.z
        return np.sum(dist)

    def get_dists_to_nearest_center(self, data, centers=None):
        """Return each point's ``distance ** z`` to its nearest center and that center's index.

        Args:
            data: array of shape ``(n_points, n_features)``.
            centers: sequence of ``(dim, n_features)`` bases; defaults to
                ``self.centers``.

        Returns:
            ``(min_dists, labels)`` as numpy arrays of length ``n_points``.
        """
        if centers is None:
            centers = self.centers
        cached = self.subset_gpu if self.use_subset else self.data_gpu
        if cached is not None and tuple(cached.shape) == tuple(np.shape(data)):
            points = cached
        else:
            points = self._to_device(data)
        n_features = points.shape[1]
        bases = torch.as_tensor(np.array(centers), dtype=torch.float32, device=self.device)
        per_center = []
        for basis in bases:
            basis = basis.reshape(self.dim, n_features)
            residual = points - (points @ basis.T) @ basis
            per_center.append(torch.linalg.norm(residual, dim=1) ** self.z)
        dists = torch.stack(per_center, dim=1)
        min_dists, labels = torch.min(dists, dim=1)
        return min_dists.cpu().numpy(), labels.cpu().numpy()

    def silent_svd(self, dim, data):
        """Return the top-``dim`` right singular vectors of ``data`` as rows.

        Uses numpy's exact SVD: deterministic and warning-free, also for a
        single-row ``data``. Only the spanned subspace matters to callers, so
        the sign of each row is irrelevant.
        """
        _, _, vt = np.linalg.svd(np.asarray(data), full_matrices=False)
        return vt[:dim]

    def _sampling_probabilities(self, dists):
        """Normalize ``dists`` into sampling weights, or return None for uniform sampling.

        None is returned when the weights cannot form a valid distribution for
        drawing ``dim`` rows without replacement (zero or non-finite total, or
        fewer than ``dim`` non-zero entries); ``np.random.choice`` would raise.
        """
        weights = np.asarray(dists, dtype=np.float64)
        total = weights.sum()
        if not np.isfinite(total) or total <= 0 or np.count_nonzero(weights) < self.dim:
            return None
        return weights / total

    def get_initial_subspaces(self, data):
        """Seed ``k`` subspaces with distance-weighted sampling.

        The first subspace is fit to ``dim`` uniformly sampled rows. Each further
        one is fit to ``dim`` rows sampled without replacement with probability
        proportional to their ``distance ** z`` to the nearest subspace chosen
        so far. Every seed basis (SVD of the sampled rows) is refined with
        ``find_best_subspace_iter``.

        Returns:
            Array of shape ``(k, dim, n_features)``.
        """
        n_points = data.shape[0]
        centers = []
        for i in range(self.k):
            if i == 0:
                probs = None
            else:
                dists, _ = self.get_dists_to_nearest_center(data, centers)
                probs = self._sampling_probabilities(dists)
            seed_points = data[np.random.choice(n_points, self.dim, replace=False, p=probs)]
            basis = self.silent_svd(self.dim, seed_points)
            centers.append(self.find_best_subspace_iter(seed_points, basis))
        return np.array(centers)